library(redist)
library(tidyverse)
library(maptools)
library(doMC)
library(igraph)
library(parallel)
source("../random_submap.R")

## Plot theme shared by the figures: theme_classic() with blank panel
## background and grid lines, black axis lines, a black panel border, blank
## facet strips, an untitled legend at the bottom and a centred plot title.
## NOTE(review): recent ggplot2 releases prefer `linewidth` over `size` for
## element_rect(); `size` is kept so behaviour is unchanged -- confirm which
## ggplot2 version this project targets.
ben_theme <- function(){
    overrides <- theme(
        panel.background = element_blank(),
        panel.grid.major = element_blank(),
        panel.grid.minor = element_blank(),
        axis.line = element_line(colour = "black"),
        panel.border = element_rect(colour = "black", fill = NA, size = 1),
        strip.background = element_blank(),
        legend.position = "bottom",
        legend.title = element_blank(),
        plot.title = element_text(hjust = 0.5)
    )
    theme_classic() + overrides
}
## Inverse-probability resampling of a chain's partitions.
##
## Keeps the draws whose `distance_parity` is at most `pop` and whose
## `beta_sequence` entry equals `beta`, weights each kept draw by
## 1 / exp(beta * constraint_pop), and resamples the kept draws with
## replacement (as many draws as were kept) using those weights.
##
## Arguments:
##   x    - list-like chain with per-draw elements `distance_parity`,
##          `beta_sequence` and `constraint_pop`, plus `partitions` holding
##          one column per draw
##   beta - value matched exactly against `x$beta_sequence`; also the
##          coefficient used in the weight
##   pop  - upper bound applied to `x$distance_parity`
##
## Returns the resampled columns of `x$partitions`, always as a matrix.
ipw <- function(x, beta, pop){
    indpop <- which(x$distance_parity <= pop)
    indbeta <- which(x$beta_sequence == beta)
    inds <- intersect(indpop, indbeta)
    psi <- x$constraint_pop[inds]
    w <- 1 / exp(beta * psi)
    ## Sample positions within `inds` instead of calling sample(inds, ...):
    ## sample() treats a single number n as the set 1:n, so a lone surviving
    ## draw would otherwise be resampled from the wrong set (or fail because
    ## the weight vector no longer matches).
    picks <- sample.int(length(inds), length(inds), replace = TRUE, prob = w)
    ## drop = FALSE keeps a single resampled column as a one-column matrix
    x$partitions[, inds[picks], drop = FALSE]
}

## Load map
## Reads the New Hampshire shapefile; random_submap() below draws the working
## map from this object.
nh <- readShapePoly("../../data/nh/nh_final.shp")

## Set parameters
## 200-row grid (nprecs = 25, ndists = 2, nsims = 1..200). In this script
## `params` is only referenced by the commented-out lookups further down; the
## live values are hard-coded instead.
params <- expand.grid(
    nprecs = 25,
    ndists = 2,
    nsims = 1:200
)

## Passed as `nsims` to every redist.mcmc() call in the MCMC section
appx_sims <- 5000000

## Directory that receives every file this script writes; showWarnings = FALSE
## keeps the call quiet when the directory already exists
dir.create("../../data/qq_test", showWarnings = FALSE)

## Get info on job
## `i` is pasted into the name of every file written below.
## NOTE(review): if SLURM_ARRAY_TASK_ID is unset or empty this is NA and the
## file names become "..._NA..." -- presumably the script is only launched as
## a SLURM array job; confirm.
i <- as.numeric(Sys.getenv("SLURM_ARRAY_TASK_ID")) 

## -----------------
## Sim in each array
## -----------------

## Get parameters
## (per-task lookups into `params` are disabled; fixed values are used)
## nprecs <- params$nprecs[i]
## ndists <- params$ndists[i]
## alg <- params$alg[i]
nprecs <- 25
ndists <- 2

## Sample map
## random_submap() comes from the sourced ../random_submap.R; presumably it
## returns a sub-map of `nprecs` precincts -- confirm against that file.
nh_sub <- random_submap(nh, nprecs)
## Shift both count columns up by one.
## NOTE(review): presumably this guards against zero counts downstream --
## confirm the intent.
nh_sub@data$POP100 <- nh_sub@data$POP100 + 1
nh_sub@data$PRES_RVOTE <- nh_sub@data$PRES_RVOTE + 1

## Convert shp file to adjacency list
adjlist <- poly2nb(nh_sub, queen = FALSE)

## Sink
## Flatten the neighbour list into a two-column edge list with one row per
## undirected edge. Only neighbours with an index larger than k are kept, so
## each pair is emitted once, from its lower-numbered endpoint. The per-unit
## pieces are built first and bound together in one rbind() call instead of
## growing the matrix one row at a time; seq_along() also keeps an empty
## neighbour list from being iterated as 1:0.
edge_rows <- lapply(seq_along(adjlist), function(k){
    nbrs <- adjlist[[k]]
    nbrs <- nbrs[nbrs > k]
    ## deparse.level = 0 keeps the pieces free of column names, matching the
    ## unnamed matrix the edge list has always been
    cbind(rep.int(k, length(nbrs)), nbrs, deparse.level = 0)
})
adjlist_map <- do.call(rbind, edge_rows)
## The output location is passed positionally: it is the second argument of
## write_delim() under both its old name (`path`) and its new name (`file`).
write_delim(data.frame(adjlist_map),
            paste0("../../data/qq_test/adjlist_map_", i, ".dat"),
            col_names = FALSE)

## Order edges
## ndscut.py reads this task's edge list on stdin; its stdout is captured as
## the ordered edge list.
ndscut_cmd <- paste0(
    "python ../ndscut.py <../../data/qq_test/adjlist_map_", i,
    ".dat >../../data/qq_test/adjlist_map_ordered_", i, ".dat"
)
system(ndscut_cmd)

## Run enumpart
## enumpart is run on the ordered edge list with -k ndists and -allsols; its
## stdout is captured as this task's solutions file.
enumpart_cmd <- paste0(
    "../../enumpart_private/enumpart ../../data/qq_test/adjlist_map_ordered_", i,
    ".dat -k ", ndists,
    " -allsols >../../data/qq_test/adjlist_map_sols_", i, ".dat"
)
system(enumpart_cmd)

## Load solutions, get true distribution
## `sols` holds one enumerated solution per line; `el` is the ordered edge
## list, one edge per row.
sols <- readLines(paste0("../../data/qq_test/adjlist_map_sols_", i, ".dat"))
el <- strsplit(readLines(paste0("../../data/qq_test/adjlist_map_ordered_", i, ".dat")), split = " ")
el <- do.call(rbind, el)
## Convert in place so the result stays a matrix even when the edge list has
## a single row (apply(., 2, as.numeric) would return a plain vector then)
storage.mode(el) <- "double"

## Turn every solution into a vector of district labels, one per unit.
## seq_along() avoids iterating over c(1, 0) when no solutions were read.
sols_out <- mclapply(seq_along(sols), function(x){
    ## Each solution line is used as a set of row indices into `el`
    sol_split <- as.numeric(strsplit(sols[x], split = " ")[[1]])
    ## drop = FALSE: a solution with one edge must still be a 2-column matrix
    el_sub <- el[sol_split, , drop = FALSE]
    comps_out <- components(graph_from_edgelist(el_sub, directed = FALSE))
    if(comps_out$no < ndists){
        ## Fewer components than districts: the graph built from the edge
        ## rows has no vertex for units beyond the highest index appearing in
        ## them, so fresh labels are appended for the missing districts
        max_num <- max(comps_out$membership)
        comps_out <- c(comps_out$membership, (max_num + 1):ndists)
    }else{
        comps_out <- comps_out$membership
    }
    return(comps_out)
}, mc.cores = detectCores())
## One column per enumerated solution
sols_out <- do.call(cbind, sols_out)

## Segregation value of every enumerated plan
target <- redist.segcalc(
    sols_out,
    grouppop = nh_sub@data$PRES_RVOTE,
    fullpop = nh_sub@data$POP100
)

## Largest relative gap between any district's population and the average
## district population, for one plan given as a vector of district labels
max_parity_gap <- function(assignment){
    district_pop <- tapply(nh_sub@data$POP100, assignment, sum)
    ideal_pop <- sum(district_pop) / length(district_pop)
    max(abs(district_pop / ideal_pop - 1))
}
popdist <- apply(sols_out, 2, max_parity_gap)

## --------------------
## Run algorithm - MCMC
## --------------------
## Thin a chain by keeping every `thin`-th draw, starting with the first.
##
## Arguments:
##   x    - list-like chain output with a `partitions` matrix (one column per
##          draw); its other elements are indexed per draw
##   thin - spacing between retained draws (default 1000)
##
## Returns a list with the same element names as `x` and class "redist".
## Matrix elements are thinned by column and keep their matrix shape even
## when a single draw survives; every other element is thinned by position.
thin_chain <- function(x, thin = 1000){
    inds <- seq(1, ncol(x$partitions), by = thin)
    ## The matrix is recognised by what it is rather than by being the first
    ## element, so the result agrees with how `inds` was computed
    x_new <- lapply(x, function(el){
        if(is.matrix(el)){
            el[, inds, drop = FALSE]
        }else{
            el[inds]
        }
    })
    class(x_new) <- "redist"
    x_new
}

## Every chain below is summarised with the same segregation call
chain_seg <- function(plans){
    redist.segcalc(plans,
                   grouppop = nh_sub@data$PRES_RVOTE,
                   fullpop = nh_sub@data$POP100)
}

## No parity
mcmc_out <- redist.mcmc(
    nh_sub, nh_sub@data$POP100,
    nsims = appx_sims,
    ndists = ndists,
    maxiterrsg = 500000
)
mcmc_seg_full <- chain_seg(thin_chain(mcmc_out))

## 20% parity
## (a) hard constraint through popcons
mcmc_out <- redist.mcmc(
    nh_sub, nh_sub@data$POP100,
    nsims = appx_sims,
    ndists = ndists,
    popcons = .2,
    maxiterrsg = 500000
)
mcmc_seg_20h <- chain_seg(thin_chain(mcmc_out))

## (b) soft population constraint at beta = -1, resampled with ipw()
mcmc_out <- redist.mcmc(
    nh_sub, nh_sub@data$POP100,
    nsims = appx_sims,
    ndists = ndists,
    constraint = "population",
    beta = -1,
    maxiterrsg = 500000
)
mcmc_out <- ipw(thin_chain(mcmc_out), beta = -1, pop = .2)
mcmc_seg_20p <- chain_seg(mcmc_out)

## (c) simulated tempering; the weights are 2^1, ..., 2^10
betaweights <- 2^(1:10)
mcmc_out <- redist.mcmc(
    nh_sub, nh_sub@data$POP100,
    nsims = appx_sims,
    ndists = ndists,
    constraint = "population",
    temper = "simulated",
    betaweights = betaweights,
    beta = -1,
    maxiterrsg = 500000
)
mcmc_out <- ipw(thin_chain(mcmc_out), beta = -1, pop = .2)
mcmc_seg_20s <- chain_seg(mcmc_out)

## 10% parity
## (a) hard constraint through popcons
mcmc_out <- redist.mcmc(
    nh_sub, nh_sub@data$POP100,
    nsims = appx_sims,
    ndists = ndists,
    popcons = .1,
    maxiterrsg = 500000
)
mcmc_seg_10h <- chain_seg(thin_chain(mcmc_out))

## (b) soft population constraint at beta = -5, resampled with ipw()
mcmc_out <- redist.mcmc(
    nh_sub, nh_sub@data$POP100,
    nsims = appx_sims,
    ndists = ndists,
    constraint = "population",
    beta = -5,
    maxiterrsg = 500000
)
mcmc_out <- ipw(thin_chain(mcmc_out), beta = -5, pop = .1)
mcmc_seg_10p <- chain_seg(mcmc_out)

## (c) simulated tempering with the same weights as the 20% run
mcmc_out <- redist.mcmc(
    nh_sub, nh_sub@data$POP100,
    nsims = appx_sims,
    ndists = ndists,
    constraint = "population",
    temper = "simulated",
    betaweights = betaweights,
    beta = -5,
    maxiterrsg = 500000
)
mcmc_out <- ipw(thin_chain(mcmc_out), beta = -5, pop = .1)
mcmc_seg_10s <- chain_seg(mcmc_out)

## ## -------------------
## ## Run algorithm - RSG
## ## -------------------

## adjlist_zind <- lapply(adjlist, function(y){y-1})

## ## No parity
## sv_out <- mclapply(1:appx_sims, function(x){
##     if(x %% (appx_sims / 10) == 0){
##         print(paste0("Done with ", x, " no parity iterations at ", Sys.time(), "."))
##     }
##     redist.rsg(
##         adjlist_zind,
##         nh_sub@data$POP100, ndists = ndists, thresh = 100,
##         verbose = FALSE
##     )$district_membership
## }, mc.cores = detectCores())
## rsg_out <- do.call(cbind, sv_out)
## inds <- seq(1, ncol(rsg_out), by = 1000) ## Thin by same amount
## rsg_seg_full <- redist.segcalc(rsg_out[,inds], grouppop = nh_sub@data$PRES_RVOTE, 
##                                fullpop = nh_sub@data$POP100)

## ## 20% parity
## sv_out <- mclapply(1:appx_sims, function(x){
##     if(x %% (appx_sims / 10) == 0){
##         print(paste0("Done with ", x, " 20% iterations at ", Sys.time(), "."))
##     }
##     redist.rsg(
##         adjlist_zind,
##         nh_sub@data$POP100, ndists = ndists, thresh = .2,
##         maxiter = 500000, verbose = FALSE
##     )$district_membership
## }, mc.cores = detectCores())
## rsg_out <- do.call(cbind, sv_out)
## rsg_seg_20p <- redist.segcalc(rsg_out[,inds], grouppop = nh_sub@data$PRES_RVOTE, 
##                               fullpop = nh_sub@data$POP100)

## ## 10% parity
## sv_out <- mclapply(1:appx_sims, function(x){
##     if(x %% (appx_sims / 10) == 0){
##         print(paste0("Done with ", x, " 10% iterations at ", Sys.time(), "."))
##     }
##     redist.rsg(
##         adjlist_zind,
##         nh_sub@data$POP100, ndists = ndists, thresh = .1,
##         maxiter = 500000, verbose = FALSE
##     )$district_membership
## }, mc.cores = detectCores())
## rsg_out <- do.call(cbind, sv_out)
## rsg_seg_10p <- redist.segcalc(rsg_out[,inds], grouppop = nh_sub@data$PRES_RVOTE, 
##                               fullpop = nh_sub@data$POP100)

## -------------
## Create output
## -------------
## Enumerated reference values restricted to each parity band
target_20 <- target[popdist <= .2]
target_10 <- target[popdist <= .1]

## Run each two-sample KS comparison exactly once and read both the
## statistic and the p-value from the same fit (previously every comparison
## was computed twice). The order here fixes the row order of `out`.
ks_fits <- list(
    ks.test(target, mcmc_seg_full),
    ks.test(target_20, mcmc_seg_20h),
    ks.test(target_20, mcmc_seg_20p),
    ks.test(target_20, mcmc_seg_20s),
    ks.test(target_10, mcmc_seg_10h),
    ks.test(target_10, mcmc_seg_10p),
    ks.test(target_10, mcmc_seg_10s) ## ,
    ## ks.test(target, rsg_seg_full),
    ## ks.test(target_20, rsg_seg_20p),
    ## ks.test(target_10, rsg_seg_10p)
)

out <- data.frame(
    ks_stat = vapply(ks_fits, function(fit) unname(fit$statistic), numeric(1)),
    ks_pval = vapply(ks_fits, function(fit) fit$p.value, numeric(1))
)
out$nprecs <- nprecs
out$ndists <- paste0(ndists, " Districts")
## Labels follow the row order of `ks_fits`
out$alg <- c(
    "MCMC\n(Hard Constraint)",
    "MCMC\n(Hard Constraint)",
    "MCMC\n(Soft Constraint)",
    "MCMC\n(Tempering Constraint)",
    "MCMC\n(Hard Constraint)",
    "MCMC\n(Soft Constraint)",
    "MCMC\n(Tempering Constraint)" ## ,
    ## rep("RSG", 3)
)

out$parity <- c("No Equal Population Constraint",
                rep("20% Equal Population Constraint", 3),
                rep("10% Equal Population Constraint", 3)## ,
                ## "No Equal Population Constraint",
                ## "20% Equal Population Constraint",
                ## "10% Equal Population Constraint"
)

## The output location is passed positionally: it is the second argument of
## write_csv() under both its old name (`path`) and its new name (`file`).
write_csv(out, paste0("../../data/qq_test/ks_test_", i, ".csv"))
